#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <stdint.h>
#include <stdbool.h>
#include <string.h>
#include <time.h>

/*
 * Build a 32-bit mask with bits `from` through `to` (both inclusive) set.
 *
 * Bit positions above 31 do not exist in the result and are ignored, so
 * create_mask(24, 32) yields the same value as create_mask(24, 31).
 * Returns 0 when from > to or when from is above 31.
 */
uint32_t create_mask(uint8_t from, uint8_t to)
{
	uint32_t r = 0;
	/* Clamp so the shift count always stays below the width of the type. */
	unsigned last = to > 31 ? 31u : to;

	/* A wider counter is required: a uint8_t could never exceed to == 255. */
	for (unsigned i = from; i <= last; i++)
		r |= UINT32_C(1) << i;

	return r;
}

/*
 * Split a 32-bit value into its four bytes, most significant byte first.
 *
 * Returns a malloc'ed 4-byte array in which element 0 holds bits 31..24 and
 * element 3 holds bits 7..0. The caller owns the array and must free() it.
 * Returns NULL if the allocation fails.
 */
uint8_t *allocate_uint_bytes(uint32_t uint)
{
	uint8_t *uints = malloc(4);

	if (uints == NULL)
		return NULL;

	/* Shift first, then mask: no shift count ever reaches the type width. */
	for (size_t i = 0; i < 4; i++)
		uints[3 - i] = (uint8_t) ((uint >> (i * 8)) & 0xFFu);

	return uints;
}

/*
 * Format the low `bytes` bytes of `uint` as a string of '0'/'1' characters,
 * most significant bit first.
 *
 * bytes selects how many of the least significant bytes are printed (0..4),
 * producing exactly bytes * 8 characters plus the terminating NUL.
 *
 * Returns a malloc'ed string the caller must free(), or NULL when bytes is
 * greater than 4 or the allocation fails.
 */
char *uint_to_binary_string(uint32_t uint, size_t bytes)
{
	/* A uint32_t only has four bytes; more would index outside the value. */
	if (bytes > 4)
		return NULL;

	size_t nbits = bytes * 8;
	char *buf = malloc(nbits + 1);

	if (buf == NULL)
		return NULL;

	/* Character i shows bit (nbits - 1 - i), so the string reads MSB first. */
	for (size_t i = 0; i < nbits; i++)
		buf[i] = ((uint >> (nbits - 1 - i)) & 1u) ? '1' : '0';
	buf[nbits] = '\0';

	return buf;
}

/*
 * Truncate a string in place so that only its last n characters remain.
 *
 * str must be a writable NUL-terminated string. When n is at least the
 * current length the string is left unchanged. Returns str (NULL if str is
 * NULL).
 */
char *binary_string_n_bits(char *str, size_t n)
{
	if (str == NULL)
		return NULL;

	/* Measure once; the length is also needed to reject oversized n. */
	size_t len = strlen(str);

	if (n >= len)
		return str;

	/* Source and destination overlap, so memmove rather than memcpy. */
	memmove(str, str + (len - n), n);
	str[n] = '\0';
	return str;
}

/*
 * Instruction opcode classes, numbered consecutively from 0.
 * NOTE(review): the names look like RISC-V RV32I opcode groups - confirm.
 * OPCODE_SIZE is not an opcode; it is the number of enumerators before it.
 */
typedef enum {
	LUI = 0,
	AUIPC,
	JAL,
	JALR,
	BRANCH,
	LOAD,
	STORE,
	OP_IMM,
	OP,
	NOP,
	INVALID,
	OPCODE_SIZE
} opcode;

/*
 * ALU operations, numbered consecutively from 0.
 * ALU_OPCODE_SIZE is the number of operations, not an operation itself.
 */
typedef enum {
	ALU_NOP = 0,
	ALU_ADD,
	ALU_SUB,
	ALU_SLL,
	ALU_SLT,
	ALU_SLTU,
	ALU_XOR,
	ALU_SRL,
	ALU_SRA,
	ALU_OR,
	ALU_AND,
	ALU_OPCODE_SIZE
} alu_opcode;

/* Printable name of each ALU operation, indexed by alu_opcode. */
char *alu_opcodes[ALU_OPCODE_SIZE] = {
	[ALU_NOP]  = "ALU_NOP",
	[ALU_ADD]  = "ALU_ADD",
	[ALU_SUB]  = "ALU_SUB",
	[ALU_SLL]  = "ALU_SLL",
	[ALU_SLT]  = "ALU_SLT",
	[ALU_SLTU] = "ALU_SLTU",
	[ALU_XOR]  = "ALU_XOR",
	[ALU_SRL]  = "ALU_SRL",
	[ALU_SRA]  = "ALU_SRA",
	[ALU_OR]   = "ALU_OR",
	[ALU_AND]  = "ALU_AND"
};

/*
 * Memory access types, numbered consecutively from 0.
 * MEMTYPE_SIZE is the number of types, not a type itself.
 */
typedef enum {
	MEM_B = 0,
	MEM_BU,
	MEM_H,
	MEM_HU,
	MEM_W,
	MEMTYPE_SIZE
} memtype;

/* Printable name of each memory access type, indexed by memtype. */
char *memtypes[MEMTYPE_SIZE] = {
	[MEM_B]  = "MEM_B",
	[MEM_BU] = "MEM_BU",
	[MEM_H]  = "MEM_H",
	[MEM_HU] = "MEM_HU",
	[MEM_W]  = "MEM_W"
};

/*
 * Branch kinds, numbered consecutively from 0.
 * BRANCHTYPE_SIZE is the number of kinds, not a kind itself.
 */
typedef enum {
	BR_NOP = 0,
	BR_BR,
	BR_CND,
	BR_CNDI,
	BRANCHTYPE_SIZE
} branchtype;

/* Printable name of each branch kind, indexed by branchtype. */
char *branchtypes[BRANCHTYPE_SIZE] = {
	[BR_NOP]  = "BR_NOP",
	[BR_BR]   = "BR_BR",
	[BR_CND]  = "BR_CND",
	[BR_CNDI] = "BR_CNDI"
};

/*
 * Write-back source selectors, numbered consecutively from 0.
 * WBTYPE_SIZE is the number of selectors, not a selector itself.
 */
typedef enum {
	WBS_ALU = 0,
	WBS_MEM,
	WBS_OPC,
	WBTYPE_SIZE
} wbtype;

/* Printable name of each write-back source, indexed by wbtype. */
char *wbtypes[WBTYPE_SIZE] = {
	[WBS_ALU] = "WBS_ALU",
	[WBS_MEM] = "WBS_MEM",
	[WBS_OPC] = "WBS_OPC"
};

// Execute-stage operation: the ALU operation plus its operand sources and values.
typedef struct {
	// operation the ALU performs
	alu_opcode aluop;
	// true: first ALU operand is the pc, false: readdata1
	bool alusrc1;
	// true: second ALU operand is imm, false: readdata2
	bool alusrc2;
	// takes part in the next-pc calculation (imm is only added when set)
	bool alusrc3;
	// source register indices; only the low 5 bits are written to the test files
	uint8_t rs1;
	uint8_t rs2;
	// operand values, presumably read from registers rs1 / rs2 - TODO confirm
	uint32_t readdata1;
	uint32_t readdata2;
	// immediate operand
	uint32_t imm;
} exec_op_t;

// Memory-stage operation; the generator passes it from input to output unchanged.
typedef struct {
	// branch kind
	branchtype branch;
	// memory read / write requests (the generator always sets memwrite = !memread)
	bool memread;
	bool memwrite;
	// access type of the memory operation
	memtype memtype;
} mem_op_t;

// Write-back operation; the generator passes it from input to output unchanged.
typedef struct {
	// destination register index; only the low 5 bits are written to the test files
	uint8_t rd;
	// write enable - presumably for register rd, TODO confirm
	bool write;
	// selects the source of the written-back value
	wbtype src;
} wb_op_t;

// A register write: enable flag, register index and data.
// NOTE(review): looks like forwarding information from later stages - confirm.
typedef struct {
	// write enable
	bool write;
	// register index; only the low 5 bits are written to the test files
	uint8_t reg;
	// value being written
	uint32_t data;
} reg_write_t;

// One stimulus record, written to the input test file.
typedef struct {
	// control flags; the generator always sets both to false
	bool stall;
	bool flush;
	// 16-bit program counter
	uint16_t pc_in;
	// alu op
	exec_op_t exec_op;
	// mem op
	mem_op_t mem_op;
	// wb op
	wb_op_t wb_op;
	// regwrites (random in the generator; they do not affect the expected output)
	reg_write_t mem;
	reg_write_t wr;
} input_t;

// One expected-result record, written to the output test file.
typedef struct {
	// copy of the input pc
	uint16_t pc_old_out;
	// next pc computed from pc / readdata1 / imm, truncated to 16 bits
	uint16_t pc_new_out;
	// result of the selected ALU operation
	uint32_t aluresult;
	// zero flag as a character: '1', '0', or '-' when it is not checked
	char zero;
	// copy of the input readdata2
	uint32_t wrdata;
	// alu op
	exec_op_t exec_op;
	// mem op
	mem_op_t mem_op;
	// wb op
	wb_op_t wb_op;
} output_t;

int main(void) {
	FILE *file_output;
	FILE *file_input;
	file_output = fopen("./testdata/output.txt", "w");
	file_input = fopen("./testdata/input.txt", "w");
	srand(time(NULL));

	int n = 50;
	for (size_t i = 0; i < 50; i++) {
		// create random input
		input_t input;
		input.stall = false;
		input.flush = false;
		input.pc_in = rand();
		input.exec_op.aluop = rand() % ALU_OPCODE_SIZE;
		input.exec_op.alusrc1 = rand() % 2 ? true : false;
		input.exec_op.alusrc2 = rand() % 2 ? true : false;
		input.exec_op.alusrc3 = rand() % 2 ? true : false;
		input.exec_op.rs1 = 0;
		input.exec_op.rs2 = 0;
		input.exec_op.readdata1 = rand();
		input.exec_op.readdata2 = rand();
		input.exec_op.imm = rand();
		input.mem_op.branch = rand() % BRANCHTYPE_SIZE;
		input.mem_op.memread = rand() % 2 ? true : false;
		input.mem_op.memwrite = !input.mem_op.memread;
		input.mem_op.memtype = rand() % MEMTYPE_SIZE;
		input.wb_op.rd = rand();
		input.wb_op.write = rand() % 2 ? true : false;
		input.wb_op.src = rand() % WBTYPE_SIZE;
		input.mem.write = rand() % 2 ? true : false;
		input.mem.reg = rand();
		input.mem.data = rand();
		input.wr.write = rand() % 2 ? true : false;
		input.wr.reg = rand();
		input.wr.data = rand();

		// pre determined output
		output_t output;
		output.pc_old_out = input.pc_in;
		output.pc_new_out = (input.exec_op.alusrc1 && input.exec_op.alusrc2 && input.exec_op.alusrc3) ? ((input.exec_op.readdata1 & 0b1111111111111111) + (input.exec_op.imm & 0b1111111111111111)) : (input.pc_in + (input.exec_op.alusrc3 ? input.exec_op.imm : 0));
		output.wrdata = input.exec_op.readdata2;
		//output.exec_op = input.exec_op;
		output.exec_op.aluop = input.exec_op.aluop;
		output.exec_op.alusrc1 = input.exec_op.alusrc1;
		output.exec_op.alusrc2 = input.exec_op.alusrc2;
		output.exec_op.alusrc3 = input.exec_op.alusrc3;
		output.exec_op.rs1 = input.exec_op.rs1;
		output.exec_op.rs2 = input.exec_op.rs2;
		output.exec_op.readdata1 = input.exec_op.readdata1;
		output.exec_op.readdata2 = input.exec_op.readdata2;
		output.exec_op.imm = input.exec_op.imm;
		output.mem_op = input.mem_op;
		output.wb_op = input.wb_op;
		uint32_t alu_in1 = input.exec_op.alusrc1 ? input.pc_in : input.exec_op.readdata1;
		uint32_t alu_in2 = input.exec_op.alusrc2 ? input.exec_op.imm : input.exec_op.readdata2;
		switch (input.exec_op.aluop) {
			case ALU_NOP:
				output.aluresult = alu_in2;
				break;
			case ALU_ADD:
				output.aluresult = alu_in1 + alu_in2;
				break;
			case ALU_SUB:
				output.aluresult = alu_in1 - alu_in2;
				break;
			case ALU_SLL:
				output.aluresult = alu_in1 << (alu_in2 & 0b11111);
				break;
			case ALU_SLT:
				output.aluresult = (((int32_t) alu_in1) < ((int32_t) alu_in2)) ? 1 : 0;
				break;
			case ALU_SLTU:
				output.aluresult = (alu_in1 < alu_in2) ? 1 : 0;
				break;
			case ALU_XOR:
				output.aluresult = alu_in1 ^ alu_in2;
				break;
			case ALU_SRL:
				output.aluresult = alu_in1 >> (alu_in2 & 0b11111);
				break;
			case ALU_SRA:
				output.aluresult = (uint32_t) ((int32_t) alu_in1) >> (alu_in2 & 0b11111);
				break;
			case ALU_OR:
				output.aluresult = alu_in1 | alu_in2;
				break;
			case ALU_AND:
				output.aluresult = alu_in1 & alu_in2;
				break;
			default:
				output.aluresult = 0;
				break;
		}
		output.zero = '-';
		if (input.exec_op.aluop == ALU_SUB)
			output.zero = (alu_in1 == alu_in2) ? '1' : '0';
		else if (input.exec_op.aluop == ALU_SLT || input.exec_op.aluop == ALU_SLTU)
			output.zero = !((output.aluresult & 1) > 0) ? '1' : '0';

		// output results
		fprintf(
			file_input,
			"# test case %zu\n# stall\n%d\n# flush\n%d\n# pc\n%s\n# aluop\n%s\n# alusrc1\n%d\n# alusrc2\n%d\n# alusrc3\n%d\n# rs1\n%s\n# rs2\n%s\n# readdata1\n%s\n# readdata2\n%s\n# imm\n%s\n# branchtype\n%s\n# memread\n%d\n# memwrite\n%d\n# memtype\n%s\n# wb rd\n%s\n# wb write\n%d\n# wb src\n%s\n# reg write mem write\n%d\n# reg write mem reg\n%s\n# reg write mem data\n%s\n# reg write wr write\n%d\n# reg write wr reg\n%s\n# reg write wr data\n%s\n",
			i,
			input.stall ? 1 : 0,
			input.flush ? 1 : 0,
			uint_to_binary_string(input.pc_in, 2),
			alu_opcodes[input.exec_op.aluop],
			input.exec_op.alusrc1 ? 1 : 0,
			input.exec_op.alusrc2 ? 1 : 0,
			input.exec_op.alusrc3 ? 1 : 0,
			binary_string_n_bits(uint_to_binary_string(output.exec_op.rs1 & 0b11111, 1), 5),
			binary_string_n_bits(uint_to_binary_string(output.exec_op.rs2 & 0b11111, 1), 5),
			uint_to_binary_string(input.exec_op.readdata1, 4),
			uint_to_binary_string(input.exec_op.readdata2, 4),
			uint_to_binary_string(input.exec_op.imm, 4),
			branchtypes[input.mem_op.branch],
			input.mem_op.memread ? 1 : 0,
			input.mem_op.memwrite ? 1 : 0,
			memtypes[input.mem_op.memtype],
			binary_string_n_bits(uint_to_binary_string(input.wb_op.rd & 0b11111, 1), 5),
			input.wb_op.write ? 1 : 0,
			wbtypes[input.wb_op.src],
			input.mem.write ? 1 : 0,
			binary_string_n_bits(uint_to_binary_string(input.mem.reg & 0b11111, 1), 5),
			uint_to_binary_string(input.mem.data, 4),
			input.wr.write ? 1 : 0,
			binary_string_n_bits(uint_to_binary_string(input.wr.reg & 0b11111, 1), 5),
			uint_to_binary_string(input.wr.data, 4)
		);
		if (i < (size_t) n - 1)
			fprintf(file_input,"\n");

		fprintf(
			file_output,
			"# test case %zu\n# pc old out\n%s\n# pc new out\n%s\n# aluresult\n%s\n# zero\n%c\n# wrdata\n%s\n# aluop\n%s\n# alusrc1\n%d\n# alusrc2\n%d\n# alusrc3\n%d\n# rs1\n%s\n# rs2\n%s\n# readdata1\n%s\n# readdata2\n%s\n# imm\n%s\n# branchtype\n%s\n# memread\n%d\n# memwrite\n%d\n# memtype\n%s\n# wb rd\n%s\n# wb write\n%d\n# wb src\n%s\n",
			i,
			uint_to_binary_string(output.pc_old_out, 2),
			uint_to_binary_string(output.pc_new_out, 2),
			uint_to_binary_string(output.aluresult, 4),
			output.zero,
			uint_to_binary_string(output.wrdata, 4),
			alu_opcodes[output.exec_op.aluop],
			output.exec_op.alusrc1 ? 1 : 0,
			output.exec_op.alusrc2 ? 1 : 0,
			output.exec_op.alusrc3 ? 1 : 0,
			binary_string_n_bits(uint_to_binary_string(output.exec_op.rs1 & 0b11111, 1), 5),
			binary_string_n_bits(uint_to_binary_string(output.exec_op.rs2 & 0b11111, 1), 5),
			uint_to_binary_string(output.exec_op.readdata1, 4),
			uint_to_binary_string(output.exec_op.readdata2, 4),
			uint_to_binary_string(output.exec_op.imm, 4),
			branchtypes[output.mem_op.branch],
			input.mem_op.memread ? 1 : 0,
			input.mem_op.memwrite ? 1 : 0,
			memtypes[input.mem_op.memtype],
			binary_string_n_bits(uint_to_binary_string(output.wb_op.rd & 0b11111, 1), 5),
			output.wb_op.write ? 1 : 0,
			wbtypes[output.wb_op.src]
		);
		if (i < (size_t) n - 1)
			fprintf(file_output, "\n");
	}
	
	fclose(file_output);
	fclose(file_input);

	return EXIT_SUCCESS;
}
